
<!DOCTYPE html>

<html lang="zh">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.17.1: http://docutils.sourceforge.net/" />

    <title>Swin Transformer Explained &#8212; 深入浅出PyTorch</title>
    
  <!-- Loaded before other Sphinx assets -->
  <link href="../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">

    
  <link rel="stylesheet"
    href="../_static/vendor/fontawesome/5.13.0/css/all.min.css">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
  <link rel="preload" as="font" type="font/woff2" crossorigin
    href="../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">

    <link rel="stylesheet" type="text/css" href="../_static/pygments.css" />
    <link rel="stylesheet" href="../_static/styles/sphinx-book-theme.css?digest=62ba249389abaaa9ffc34bf36a076bdc1d65ee18" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../_static/togglebutton.css" />
    <link rel="stylesheet" type="text/css" href="../_static/mystnb.css" />
    <link rel="stylesheet" type="text/css" href="../_static/plot_directive.css" />
    
  <!-- Pre-loaded scripts that we'll load fully later -->
  <link rel="preload" as="script" href="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">

    <script data-url_root="../" id="documentation_options" src="../_static/documentation_options.js"></script>
    <script src="../_static/jquery.js"></script>
    <script src="../_static/underscore.js"></script>
    <script src="../_static/doctools.js"></script>
    <script>let toggleHintShow = 'Click to show';</script>
    <script>let toggleHintHide = 'Click to hide';</script>
    <script>let toggleOpenOnPrint = 'true';</script>
    <script src="../_static/togglebutton.js"></script>
    <script src="../_static/scripts/sphinx-book-theme.js?digest=f31d14ad54b65d19161ba51d4ffff3a77ae00456"></script>
    <script>var togglebuttonSelector = '.toggle, .admonition.dropdown, .tag_hide_input div.cell_input, .tag_hide-input div.cell_input, .tag_hide_output div.cell_output, .tag_hide-output div.cell_output, .tag_hide_cell.cell, .tag_hide-cell.cell';</script>
    <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script>
    <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <link rel="index" title="索引" href="../genindex.html" />
    <link rel="search" title="搜索" href="../search.html" />
    <link rel="prev" title="ViT解读" href="ViT%E8%A7%A3%E8%AF%BB.html" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <meta name="docsearch:language" content="zh">
    

    <!-- Google Analytics -->
    
  </head>
  <body data-spy="scroll" data-target="#bd-toc-nav" data-offset="60">
<!-- Checkboxes to toggle the left sidebar -->
<input type="checkbox" class="sidebar-toggle" name="__navigation" id="__navigation" aria-label="Toggle navigation sidebar">
<label class="overlay overlay-navbar" for="__navigation">
    <div class="visually-hidden">Toggle navigation sidebar</div>
</label>
<!-- Checkboxes to toggle the in-page toc -->
<input type="checkbox" class="sidebar-toggle" name="__page-toc" id="__page-toc" aria-label="Toggle in-page Table of Contents">
<label class="overlay overlay-pagetoc" for="__page-toc">
    <div class="visually-hidden">Toggle in-page Table of Contents</div>
</label>
<!-- Headers at the top -->
<div class="announcement header-item noprint"></div>
<div class="header header-item noprint"></div>

    
    <div class="container-fluid" id="banner"></div>

    

    <div class="container-xl">
      <div class="row">
          
<!-- Sidebar -->
<div class="bd-sidebar noprint" id="site-navigation">
    <div class="bd-sidebar__content">
        <div class="bd-sidebar__top"><div class="navbar-brand-box">
    <a class="navbar-brand text-wrap" href="../index.html">
      
      
      
      <h1 class="site-logo" id="site-title">深入浅出PyTorch</h1>
      
    </a>
</div><form class="bd-search d-flex align-items-center" action="../search.html" method="get">
  <i class="icon fas fa-search"></i>
  <input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main">
    <div class="bd-toc-item active">
        <p aria-level="2" class="caption" role="heading">
 <span class="caption-text">
  目录
 </span>
</p>
<ul class="current nav bd-sidenav">
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/index.html">
   第零章：前置知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/>
  <label for="toctree-checkbox-1">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.1%20%E4%BA%BA%E5%B7%A5%E6%99%BA%E8%83%BD%E7%AE%80%E5%8F%B2.html">
     人工智能简史
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.2%20%E8%AF%84%E4%BB%B7%E6%8C%87%E6%A0%87.html">
     模型评价指标
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.3%20%E5%B8%B8%E7%94%A8%E5%8C%85%E7%9A%84%E5%AD%A6%E4%B9%A0.html">
     常用包的学习
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E9%9B%B6%E7%AB%A0/0.4%20Jupyter%E7%9B%B8%E5%85%B3%E6%93%8D%E4%BD%9C.html">
     Jupyter notebook/Lab 简述
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/index.html">
   第一章：PyTorch的简介和安装
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/>
  <label for="toctree-checkbox-2">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.1%20PyTorch%E7%AE%80%E4%BB%8B.html">
     1.1 PyTorch简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.2%20PyTorch%E7%9A%84%E5%AE%89%E8%A3%85.html">
     1.2 PyTorch的安装
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%80%E7%AB%A0/1.3%20PyTorch%E7%9B%B8%E5%85%B3%E8%B5%84%E6%BA%90.html">
     1.3 PyTorch相关资源
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/index.html">
   第二章：PyTorch基础知识
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/>
  <label for="toctree-checkbox-3">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.1%20%E5%BC%A0%E9%87%8F.html">
     2.1 张量
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.2%20%E8%87%AA%E5%8A%A8%E6%B1%82%E5%AF%BC.html">
     2.2 自动求导
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.3%20%E5%B9%B6%E8%A1%8C%E8%AE%A1%E7%AE%97%E7%AE%80%E4%BB%8B.html">
     2.3 并行计算简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%8C%E7%AB%A0/2.4%20AI%E7%A1%AC%E4%BB%B6%E5%8A%A0%E9%80%9F%E8%AE%BE%E5%A4%87.html">
     AI硬件加速设备
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/index.html">
   第三章：PyTorch的主要组成模块
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/>
  <label for="toctree-checkbox-4">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.1%20%E6%80%9D%E8%80%83%EF%BC%9A%E5%AE%8C%E6%88%90%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E7%9A%84%E5%BF%85%E8%A6%81%E9%83%A8%E5%88%86.html">
     3.1 思考：完成深度学习的必要部分
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.2%20%E5%9F%BA%E6%9C%AC%E9%85%8D%E7%BD%AE.html">
     3.2 基本配置
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.3%20%E6%95%B0%E6%8D%AE%E8%AF%BB%E5%85%A5.html">
     3.3 数据读入
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.4%20%E6%A8%A1%E5%9E%8B%E6%9E%84%E5%BB%BA.html">
     3.4 模型构建
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.5%20%E6%A8%A1%E5%9E%8B%E5%88%9D%E5%A7%8B%E5%8C%96.html">
     3.5 模型初始化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.6%20%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     3.6 损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.7%20%E8%AE%AD%E7%BB%83%E4%B8%8E%E8%AF%84%E4%BC%B0.html">
     3.7 训练和评估
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.8%20%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     3.8 可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%89%E7%AB%A0/3.9%20%E4%BC%98%E5%8C%96%E5%99%A8.html">
     3.9 PyTorch优化器
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/index.html">
   第四章：PyTorch基础实战
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/>
  <label for="toctree-checkbox-5">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/4.1%20ResNet.html">
     4.1 ResNet
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%9B%9B%E7%AB%A0/4.4%20FashionMNIST%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     基础实战——FashionMNIST时装分类
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/index.html">
   第五章：PyTorch模型定义
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/>
  <label for="toctree-checkbox-6">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.1%20PyTorch%E6%A8%A1%E5%9E%8B%E5%AE%9A%E4%B9%89%E7%9A%84%E6%96%B9%E5%BC%8F.html">
     5.1 PyTorch模型定义的方式
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.2%20%E5%88%A9%E7%94%A8%E6%A8%A1%E5%9E%8B%E5%9D%97%E5%BF%AB%E9%80%9F%E6%90%AD%E5%BB%BA%E5%A4%8D%E6%9D%82%E7%BD%91%E7%BB%9C.html">
     5.2 利用模型块快速搭建复杂网络
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.3%20PyTorch%E4%BF%AE%E6%94%B9%E6%A8%A1%E5%9E%8B.html">
     5.3 PyTorch修改模型
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%BA%94%E7%AB%A0/5.4%20PyTorh%E6%A8%A1%E5%9E%8B%E4%BF%9D%E5%AD%98%E4%B8%8E%E8%AF%BB%E5%8F%96.html">
     5.4 PyTorch模型保存与读取
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/index.html">
   第六章：PyTorch进阶训练技巧
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/>
  <label for="toctree-checkbox-7">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.1%20%E8%87%AA%E5%AE%9A%E4%B9%89%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0.html">
     6.1 自定义损失函数
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.2%20%E5%8A%A8%E6%80%81%E8%B0%83%E6%95%B4%E5%AD%A6%E4%B9%A0%E7%8E%87.html">
     6.2 动态调整学习率
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-torchvision.html">
     6.3 模型微调-torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.3%20%E6%A8%A1%E5%9E%8B%E5%BE%AE%E8%B0%83-timm.html">
     6.3 模型微调 - timm
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.4%20%E5%8D%8A%E7%B2%BE%E5%BA%A6%E8%AE%AD%E7%BB%83.html">
     6.4 半精度训练
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.5%20%E6%95%B0%E6%8D%AE%E5%A2%9E%E5%BC%BA-imgaug.html">
     6.5 数据增强-imgaug
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AD%E7%AB%A0/6.6%20%E4%BD%BF%E7%94%A8argparse%E8%BF%9B%E8%A1%8C%E8%B0%83%E5%8F%82.html">
     6.6 使用argparse进行调参
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/index.html">
   第七章：PyTorch可视化
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/>
  <label for="toctree-checkbox-8">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.1%20%E5%8F%AF%E8%A7%86%E5%8C%96%E7%BD%91%E7%BB%9C%E7%BB%93%E6%9E%84.html">
     7.1 可视化网络结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.2%20CNN%E5%8D%B7%E7%A7%AF%E5%B1%82%E5%8F%AF%E8%A7%86%E5%8C%96.html">
     7.2 CNN可视化
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.3%20%E4%BD%BF%E7%94%A8TensorBoard%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.3 使用TensorBoard可视化训练过程
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B8%83%E7%AB%A0/7.4%20%E4%BD%BF%E7%94%A8wandb%E5%8F%AF%E8%A7%86%E5%8C%96%E8%AE%AD%E7%BB%83%E8%BF%87%E7%A8%8B.html">
     7.4 使用wandb可视化训练过程
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/index.html">
   第八章：PyTorch生态简介
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/>
  <label for="toctree-checkbox-9">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.1%20%E6%9C%AC%E7%AB%A0%E7%AE%80%E4%BB%8B.html">
     8.1 本章简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.2%20%E5%9B%BE%E5%83%8F%20-%20torchvision.html">
     8.2 torchvision
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.3%20%E8%A7%86%E9%A2%91%20-%20PyTorchVideo.html">
     8.3 PyTorchVideo简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.4%20%E6%96%87%E6%9C%AC%20-%20torchtext.html">
     8.4 torchtext简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E5%85%AB%E7%AB%A0/8.5%20%E9%9F%B3%E9%A2%91%20-%20torchaudio.html">
     8.5 torchaudio简介
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 has-children">
  <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/index.html">
   第九章：PyTorch的模型部署
  </a>
  <input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/>
  <label for="toctree-checkbox-10">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul>
   <li class="toctree-l2">
    <a class="reference internal" href="../%E7%AC%AC%E4%B9%9D%E7%AB%A0/9.1%20%E4%BD%BF%E7%94%A8ONNX%E8%BF%9B%E8%A1%8C%E9%83%A8%E7%BD%B2%E5%B9%B6%E6%8E%A8%E7%90%86.html">
     9.1 使用ONNX进行部署并推理
    </a>
   </li>
  </ul>
 </li>
 <li class="toctree-l1 current active has-children">
  <a class="reference internal" href="index.html">
   第十章：常见代码解读
  </a>
  <input checked="" class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/>
  <label for="toctree-checkbox-11">
   <i class="fas fa-chevron-down">
   </i>
  </label>
  <ul class="current">
   <li class="toctree-l2">
    <a class="reference internal" href="10.1%20%E5%9B%BE%E5%83%8F%E5%88%86%E7%B1%BB.html">
     10.1 图像分类简介（补充中）
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="10.2%20%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B.html">
     目标检测简介
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="10.3%20%E5%9B%BE%E5%83%8F%E5%88%86%E5%89%B2.html">
     10.3 图像分割简介（补充中）
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="ResNet%E6%BA%90%E7%A0%81%E8%A7%A3%E8%AF%BB.html">
     ResNet源码解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="RNN%E8%AF%A6%E8%A7%A3%E5%8F%8A%E5%85%B6%E5%AE%9E%E7%8E%B0.html">
     文章结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="LSTM%E8%A7%A3%E8%AF%BB%E5%8F%8A%E5%AE%9E%E6%88%98.html">
     文章结构
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="Transformer%20%E8%A7%A3%E8%AF%BB.html">
     Transformer 解读
    </a>
   </li>
   <li class="toctree-l2">
    <a class="reference internal" href="ViT%E8%A7%A3%E8%AF%BB.html">
     ViT解读
    </a>
   </li>
   <li class="toctree-l2 current active">
    <a class="current reference internal" href="#">
     Swin Transformer解读
    </a>
   </li>
  </ul>
 </li>
</ul>

    </div>
</nav></div>
        <div class="bd-sidebar__bottom">
             <!-- To handle the deprecated key -->
            
            <div class="navbar_extra_footer">
            Theme by the <a href="https://ebp.jupyterbook.org">Executable Book Project</a>
            </div>
            
        </div>
    </div>
    <div id="rtd-footer-container"></div>
</div>


          


          
<!-- A tiny helper pixel to detect if we've scrolled -->
<div class="sbt-scroll-pixel-helper"></div>
<!-- Main content -->
<div class="col py-0 content-container">
    
    <div class="header-article row sticky-top noprint">
        



<div class="col py-1 d-flex header-article-main">
    <div class="header-article__left">
        
        <label for="__navigation"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="right"
title="Toggle navigation"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-bars"></i>
  </span>

</label>

        
    </div>
    <div class="header-article__right">
<button onclick="toggleFullScreen()"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="bottom"
title="Fullscreen mode"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-expand"></i>
  </span>

</button>

<div class="menu-dropdown menu-dropdown-repository-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Source repositories">
      <i class="fab fa-github"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Source repository"
>
  

<span class="headerbtn__icon-container">
  <i class="fab fa-github"></i>
  </span>
<span class="headerbtn__text-container">repository</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/issues/new?title=Issue%20on%20page%20%2F第十章/Swin-Transformer解读.html&body=Your%20issue%20content%20here."
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Open an issue"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-lightbulb"></i>
  </span>
<span class="headerbtn__text-container">open issue</span>
</a>

      </li>
      
      <li>
        <a href="https://github.com/datawhalechina/thorough-pytorch/edit/master/第十章/Swin-Transformer解读.md"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Edit this page"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-pencil-alt"></i>
  </span>
<span class="headerbtn__text-container">suggest edit</span>
</a>

      </li>
      
    </ul>
  </div>
</div>

<div class="menu-dropdown menu-dropdown-download-buttons">
  <button class="headerbtn menu-dropdown__trigger"
      aria-label="Download this page">
      <i class="fas fa-download"></i>
  </button>
  <div class="menu-dropdown__content">
    <ul>
      <li>
        <a href="../_sources/第十章/Swin-Transformer解读.md.txt"
   class="headerbtn"
   data-toggle="tooltip"
data-placement="left"
title="Download source file"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file"></i>
  </span>
<span class="headerbtn__text-container">.md</span>
</a>

      </li>
      
      <li>
        
<button onclick="printPdf(this)"
  class="headerbtn"
  data-toggle="tooltip"
data-placement="left"
title="Print to PDF"
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-file-pdf"></i>
  </span>
<span class="headerbtn__text-container">.pdf</span>
</button>

      </li>
      
    </ul>
  </div>
</div>
<label for="__page-toc"
  class="headerbtn headerbtn-page-toc"
  
>
  

<span class="headerbtn__icon-container">
  <i class="fas fa-list"></i>
  </span>

</label>

    </div>
</div>

<!-- Table of contents -->
<div class="col-md-3 bd-toc show noprint">
    <div class="tocsection onthispage pt-5 pb-3">
        <i class="fas fa-list"></i> Contents
    </div>
    <nav id="bd-toc-nav" aria-label="Page">
        <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id1">
   Preface
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   Model Architecture
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#patch-embedding">
   Patch Embedding
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#patch-merging">
   Patch Merging
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#window-partition-reverse">
   Window Partition/Reverse
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#window-attention">
   Window Attention
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     Relative position encoding: intuition
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id4">
     Relative position encoding: code walkthrough
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#shifted-window-attention">
   Shifted Window Attention
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id5">
     Shifting the feature map
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#attention-mask">
     Attention Mask
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#w-msamsa">
   Complexity of W-MSA vs. MSA
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#msa">
     Computation cost of the MSA module
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#w-msa">
     Computation cost of the W-MSA module
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id6">
   Overall flow diagram
  </a>
 </li>
</ul>

    </nav>
</div>
    </div>
    <div class="article row">
        <div class="col pl-md-3 pl-lg-5 content-container">
            <!-- Table of contents that is only displayed when printing the page -->
            <div id="jb-print-docs-body" class="onlyprint">
                <h1>Swin Transformer Explained</h1>
                <!-- Table of contents -->
                <div id="print-main-content">
                    <div id="jb-print-toc">
                        
                        <div>
                            <h2> Contents </h2>
                        </div>
                        <nav aria-label="Page">
                            <ul class="visible nav section-nav flex-column">
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id1">
   Preface
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id2">
   Model Architecture
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#patch-embedding">
   Patch Embedding
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#patch-merging">
   Patch Merging
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#window-partition-reverse">
   Window Partition/Reverse
  </a>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#window-attention">
   Window Attention
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id3">
     Relative position encoding: intuition
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id4">
     Relative position encoding: code walkthrough
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#shifted-window-attention">
   Shifted Window Attention
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#id5">
     Shifting the feature map
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#attention-mask">
     Attention Mask
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#w-msamsa">
   Complexity of W-MSA vs. MSA
  </a>
  <ul class="nav section-nav flex-column">
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#msa">
     Computation cost of the MSA module
    </a>
   </li>
   <li class="toc-h3 nav-item toc-entry">
    <a class="reference internal nav-link" href="#w-msa">
     Computation cost of the W-MSA module
    </a>
   </li>
  </ul>
 </li>
 <li class="toc-h2 nav-item toc-entry">
  <a class="reference internal nav-link" href="#id6">
   Overall flow diagram
  </a>
 </li>
</ul>

                        </nav>
                    </div>
                </div>
            </div>
            <main id="main-content" role="main">
                
              <div>
                
  <section class="tex2jax_ignore mathjax_ignore" id="swin-transformer">
<h1>Swin Transformer Explained<a class="headerlink" href="#swin-transformer" title="Permalink to this heading">#</a></h1>
<p>
<font size=3><b>[Swin-T] Swin Transformer: Hierarchical Vision Transformer using Shifted Windows</b></font>
<br>
<font size=2>Ze Liu, Yutong Lin, Yue Cao, Han Hu, Yixuan Wei, Zheng Zhang, Stephen Lin, Baining Guo.</font>
<br>
<font size=2>ICCV 2021.</font>
<a href='https://arxiv.org/pdf/2103.14030.pdf'>[paper]</a> <a href='https://github.com/microsoft/Swin-Transformer'>[code]</a> 
<br>
<font size=3>Commentary by: 沈豪, Fudan University PhD, Datawhale member</font>
<br>
</p>
<section id="id1">
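<h2>Preface<a class="headerlink" href="#id1" title="Permalink to this heading">#</a></h2>
<p>"<a class="reference external" href="https://arxiv.org/abs/2103.14030">Swin Transformer: Hierarchical Vision Transformer using Shifted Windows</a>", the ICCV 2021 best paper, topped the leaderboards of many CV tasks and outperformed backbones such as DeiT, ViT, and EfficientNet. It is now often used in place of classic CNN architectures as a <strong>general-purpose backbone for computer vision</strong>. Building on the ideas of ViT, it introduces a <strong>shifted window mechanism</strong> that lets the model learn information across windows. Its <strong>downsampling layers</strong> let the model handle high-resolution images, reduce computation, and attend to both global and local information. This article walks through the Swin Transformer architecture in detail from both the theory and the code.</p>
<p>Applying the Transformer from natural language processing to computer vision currently faces two main challenges:</p>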
<ul class="simple">
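<li><p>Visual entities vary widely. For example, the same object photographed from different angles produces very different pixel data. A vision Transformer also does not necessarily perform well across different scenes.</p></li>
<li><p>Images are high-resolution and contain many pixels. With a ViT-style model, the cost of self-attention grows with the square of the number of pixels (more precisely, of the number of patch tokens).</p></li>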
</ul>
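<p>To address these two problems, the paper proposes Swin Transformer, which is built on a <strong>shifted window mechanism and a hierarchical design (downsampling layers)</strong>.</p>
<p>The <strong>window operation</strong> comprises <strong>non-overlapping local windows and overlapping cross-window connections</strong>. Restricting attention to a window of fixed size <strong>brings in the locality of CNN convolutions and also greatly reduces computation</strong>, which then grows only linearly with the number of windows. The hierarchical <strong>downsampling</strong> design gradually enlarges the receptive field, so the attention mechanism can also capture <strong>global</strong> features.</p>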
<img src="./figures/Swin-T&ViT.png" alt="Swin-T&ViT" style="zoom:50%;" />
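<p>To make the saving concrete, the following is a minimal sketch that counts query-key pairs for global attention and for window attention. The function names are hypothetical, and the 56 x 56 token grid with a window size of 7 is only an illustrative setting; the count ignores heads and channel dimensions, so it is a rough proxy for cost rather than an exact FLOP figure.</p>

```python
def global_attention_pairs(height, width):
    """Query-key pairs when every token attends to every token."""
    n = height * width
    return n * n


def window_attention_pairs(height, width, window_size):
    """Query-key pairs when attention is restricted to non-overlapping windows."""
    if height % window_size or width % window_size:
        raise ValueError("feature map must be divisible by window_size")
    num_windows = (height // window_size) * (width // window_size)
    tokens_per_window = window_size * window_size
    # Each window costs tokens_per_window ** 2, so the total is linear
    # in the number of windows.
    return num_windows * tokens_per_window ** 2


if __name__ == "__main__":
    g = global_attention_pairs(56, 56)
    w = window_attention_pairs(56, 56, 7)
    print(g, w, g // w)
```

<p>With the window size fixed, doubling the number of windows doubles the windowed count, whereas the global count grows with the square of the number of tokens.</p>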
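<p>At the end of the paper, the authors show through extensive experiments that Swin Transformer improves on previous SOTA models, especially on the ADE20K and COCO datasets. The results also show that Swin Transformer can serve as a general-purpose backbone.</p>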
</section>
<section id="id2">
<h2>Model Architecture<a class="headerlink" href="#id2" title="Permalink to this heading">#</a></h2>
<p><img alt="Architecture" src="../_images/Architecture.png" /></p>
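<p>The model follows a hierarchical design with 4 Stages. Every Stage except the first begins with a <strong>Patch Merging</strong> layer that reduces the resolution of the input feature map. This <strong>downsampling</strong> enlarges the receptive field layer by layer, as in a CNN, so that global information can be captured.</p>
<p>From the paper's perspective:</p>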
<ul class="simple">
<li><p>At the input, the model applies a <code class="docutils literal notranslate"><span class="pre">Patch</span> <span class="pre">Partition</span></code>, which corresponds to the <code class="docutils literal notranslate"><span class="pre">Patch</span> <span class="pre">Embedding</span></code> operation in ViT. A convolution layer with a <strong>Patch_size</strong> of 4 splits the image into individual <strong>patches</strong> and maps each one to an <code class="docutils literal notranslate"><span class="pre">Embedding</span></code>, giving an <strong>embedding_size</strong> of 48, that is, 4 × 4 × 3 (the <strong>number of channels</strong> of an image in CV can be thought of as the <strong>word-embedding length</strong> of a token in NLP).</p></li>
<li><p>Then, in the first Stage, a <code class="docutils literal notranslate"><span class="pre">Linear</span> <span class="pre">Embedding</span></code> adjusts the number of channels to C.</p></li>
<li><p>Every Stage except the first consists of a <code class="docutils literal notranslate"><span class="pre">Patch</span> <span class="pre">Merging</span></code> layer followed by several <code class="docutils literal notranslate"><span class="pre">Swin</span> <span class="pre">Transformer</span> <span class="pre">Block</span></code>s.</p></li>
<li><p>The <code class="docutils literal notranslate"><span class="pre">Patch</span> <span class="pre">Merging</span></code> module reduces the image resolution at the start of each Stage, performing the downsampling.</p></li>
<li><p>The structure of the <code class="docutils literal notranslate"><span class="pre">Swin</span> <span class="pre">Transformer</span> <span class="pre">Block</span></code> is shown in the right-hand part of the figure. It consists mainly of <code class="docutils literal notranslate"><span class="pre">LayerNorm</span></code>, <code class="docutils literal notranslate"><span class="pre">Window</span> <span class="pre">Attention</span></code>, <code class="docutils literal notranslate"><span class="pre">Shifted</span> <span class="pre">Window</span> <span class="pre">Attention</span></code>, and <code class="docutils literal notranslate"><span class="pre">MLP</span></code>.</p></li>
</ul>
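<p>The shapes implied by this description can be traced with a few lines of plain Python. This is only a sketch: the helper name <code class="docutils literal notranslate"><span class="pre">stage_shapes</span></code> is hypothetical, and the defaults (224-pixel input, patch size 4, C = 96) are taken from the code shown later on this page.</p>

```python
def stage_shapes(img_size=224, patch_size=4, embed_dim=96, num_stages=4):
    """Return (height, width, channels) of the feature map after each stage."""
    side = img_size // patch_size  # tokens per side after Patch Partition
    dim = embed_dim                # channel count C after Linear Embedding
    shapes = []
    for stage in range(num_stages):
        if stage != 0:
            side //= 2  # Patch Merging halves H and W
            dim *= 2    # and doubles the channel count
        shapes.append((side, side, dim))
    return shapes


for number, shape in enumerate(stage_shapes(), start=1):
    print(f"stage {number}: {shape}")
```

<p>Each stage halves the spatial side and doubles the channels, which is the same pyramid shape a CNN backbone produces.</p>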
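<p>From the code's point of view:</p>
<p>In the code released by Microsoft Research Asia, <code class="docutils literal notranslate"><span class="pre">Patch</span> <span class="pre">Merging</span></code> is instead the last operation of each stage: the input first goes through the <code class="docutils literal notranslate"><span class="pre">Swin</span> <span class="pre">Transformer</span> <span class="pre">Block</span></code>s and is then downsampled. <strong>The last stage needs no downsampling</strong>; its output goes directly through the final fully connected layer, and the loss is computed against the <strong>target label</strong>.</p>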
<p><img alt="code_Architecture" src="../_images/code_Architecture.png" /></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># window_size=7 </span>
<span class="c1"># input_batch_image.shape=[128,3,224,224]</span>
<span class="k">class</span> <span class="nc">SwinTransformer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="o">...</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="o">...</span>
        <span class="c1"># absolute position embedding</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ape</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">absolute_pos_embed</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_patches</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">))</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">pos_drop</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">p</span><span class="o">=</span><span class="n">drop_rate</span><span class="p">)</span>

        <span class="c1"># build layers</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layers</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ModuleList</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">i_layer</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span>
            <span class="n">layer</span> <span class="o">=</span> <span class="n">BasicLayer</span><span class="p">(</span><span class="o">...</span><span class="p">)</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">avgpool</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">AdaptiveAvgPool1d</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">head</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span> <span class="k">if</span> <span class="n">num_classes</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="n">nn</span><span class="o">.</span><span class="n">Identity</span><span class="p">()</span>

    <span class="k">def</span> <span class="nf">forward_features</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_embed</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Patch Partition</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">ape</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">absolute_pos_embed</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_drop</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>  <span class="c1"># Batch_size, num_tokens (H*W), Channels</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">avgpool</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>  <span class="c1"># Batch_size Channels 1</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">forward_features</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">head</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># self.head =&gt; Linear(in=Channels,out=Classification_num)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
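<p>The ordering used in the code (blocks first, downsampling last, no downsampling in the final layer) can be traced with a small sketch. The function <code class="docutils literal notranslate"><span class="pre">trace_code_order</span></code> is a hypothetical helper, not part of the repository, and its defaults assume the 56 × 56 × 96 output of the patch embedding.</p>

```python
def trace_code_order(resolution=56, dim=96, num_layers=4):
    """List (resolution, dim, downsample_after) for each layer, in code order."""
    trace = []
    for i_layer in range(num_layers):
        is_last = i_layer == num_layers - 1
        # The Swin Transformer Blocks of this layer run at (resolution, dim).
        trace.append((resolution, dim, not is_last))
        if not is_last:
            # Patch Merging closes the layer: halve H and W, double the channels.
            resolution //= 2
            dim *= 2
    # dim is now the channel count that reaches norm, avgpool and head.
    return trace, dim


trace, final_dim = trace_code_order()
for resolution, dim, downsample in trace:
    print(resolution, dim, downsample)
```

<p>The blocks see the same resolutions and channel counts as in the paper's description; only the place where the downsampling is grouped differs.</p>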
<p>Several things are handled differently from ViT:</p>
<ul class="simple">
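<li><p>ViT adds a position encoding to the embeddings at the input. In Swin-T this is only an <strong>option</strong> (<code class="docutils literal notranslate"><span class="pre">self.ape</span></code>); instead, Swin-T applies a <strong>relative position encoding</strong> when computing attention, which I consider the most ingenious part of the paper's design.</p></li>
<li><p>ViT adds a separate learnable parameter that serves as the classification token. Swin-T instead <strong>averages over all tokens</strong> (avgpool) and classifies from the result, much like the global average pooling layer at the end of a CNN.</p></li>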
</ul>
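<p>The pooling step in the second point amounts to a mean over the token axis. A minimal pure-Python sketch follows; the real model uses <code class="docutils literal notranslate"><span class="pre">nn.AdaptiveAvgPool1d(1)</span></code> on tensors, and <code class="docutils literal notranslate"><span class="pre">mean_pool</span></code> is a hypothetical name used only for this illustration.</p>

```python
def mean_pool(tokens):
    """Average a list of token vectors (num_tokens x channels) over the token axis."""
    num_tokens = len(tokens)
    channels = len(tokens[0])
    return [sum(token[c] for token in tokens) / num_tokens for c in range(channels)]


tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 tokens, 2 channels
pooled = mean_pool(tokens)  # a single vector of length 2, which the linear head would classify
print(pooled)
```

<p>No extra token has to be learned: every spatial position contributes equally to the vector that is classified.</p>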
</section>
<section id="patch-embedding">
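<h2>Patch Embedding<a class="headerlink" href="#patch-embedding" title="Permalink to this heading">#</a></h2>
<p>Before the input enters the blocks, the image has to be cut into patches, and each patch is embedded as a vector.</p>
<p>Concretely, the original image is cut into tiles of size <code class="docutils literal notranslate"><span class="pre">patch_size</span> <span class="pre">*</span> <span class="pre">patch_size</span></code>, and each tile is then embedded.</p>
<p>This can be done with a 2D convolution whose <strong>stride and kernel_size are both set to patch_size</strong>; the number of output channels determines the size of the embedding vector. Finally, the H and W dimensions are flattened into one and moved to dimension 1, so that the channel dimension comes last.</p>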
<blockquote>
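<div><p>The paper describes a 48-dimensional patch feature that a linear embedding then projects to C, whereas the code does both steps in a single convolution with 96 output channels. Below we follow the code throughout.</p>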
<p>Batch_size=128</p>
</div></blockquote>
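<p>To see what a convolution with stride equal to its kernel size operates on, the non-overlapping patches can be extracted by hand. This is a sketch on nested lists, with a hypothetical helper <code class="docutils literal notranslate"><span class="pre">extract_patches</span></code> and a toy 8 × 8 image rather than a 224 × 224 one; the learned projection to 96 channels is left out.</p>

```python
def extract_patches(image, patch_size):
    """Split an H x W x C nested list into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = []
            for row in range(top, top + patch_size):
                for col in range(left, left + patch_size):
                    patch.extend(image[row][col])  # append the C channel values
            patches.append(patch)
    return patches


# Toy 8 x 8 image with 3 channels; each pixel stores its own (row, col, 0).
image = [[[r, c, 0] for c in range(8)] for r in range(8)]
patches = extract_patches(image, patch_size=4)
print(len(patches), len(patches[0]))
```

<p>Each flattened patch holds 4 × 4 × 3 = 48 values. The convolution in <code class="docutils literal notranslate"><span class="pre">PatchEmbed</span></code> applies one shared learned linear map to every such patch, which is why the output has one <code class="docutils literal notranslate"><span class="pre">embed_dim</span></code>-sized vector per patch.</p>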
<p><img alt="Patch_embedding" src="../_images/Patch_embedding.png" /></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torch</span>
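<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">from</span> <span class="nn">timm.models.layers</span> <span class="kn">import</span> <span class="n">to_2tuple</span>  <span class="c1"># turns an int into a pair; requires the timm package</span>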

<span class="k">class</span> <span class="nc">PatchEmbed</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">img_size</span><span class="o">=</span><span class="mi">224</span><span class="p">,</span> <span class="n">patch_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">in_chans</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">embed_dim</span><span class="o">=</span><span class="mi">96</span><span class="p">,</span> <span class="n">norm_layer</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="n">img_size</span> <span class="o">=</span> <span class="n">to_2tuple</span><span class="p">(</span><span class="n">img_size</span><span class="p">)</span> <span class="c1"># -&gt; (img_size, img_size)</span>
        <span class="n">patch_size</span> <span class="o">=</span> <span class="n">to_2tuple</span><span class="p">(</span><span class="n">patch_size</span><span class="p">)</span> <span class="c1"># -&gt; (patch_size, patch_size)</span>
        <span class="n">patches_resolution</span> <span class="o">=</span> <span class="p">[</span><span class="n">img_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">img_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">img_size</span> <span class="o">=</span> <span class="n">img_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">patches_resolution</span> <span class="o">=</span> <span class="n">patches_resolution</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span> <span class="o">=</span> <span class="n">patches_resolution</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">patches_resolution</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">in_chans</span> <span class="o">=</span> <span class="n">in_chans</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_chans</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="n">patch_size</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">norm_layer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="kc">None</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="c1"># With the default arguments the embedding size is 96, the value used in the code</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># output shape: (N, 96, 224/4, 224/4)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="c1"># flatten the H and W dimensions: (N, 96, 56*56)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>  <span class="c1"># move the channel dimension last: (N, 56*56, 96)</span>
        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
</section>
<section id="patch-merging">
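<h2>Patch Merging<a class="headerlink" href="#patch-merging" title="Permalink to this heading">#</a></h2>
<p>This module downsamples at the start of each stage: it reduces the resolution and adjusts the number of channels, which gives the model its hierarchical design and also saves some computation.</p>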
<blockquote>
<div><p>In a CNN, the resolution is reduced at the start of each stage by a convolution or pooling layer with <code class="docutils literal notranslate"><span class="pre">stride=2</span></code>.</p>
</div></blockquote>
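<p>Each downsampling step is by a factor of two, so <strong>elements are selected with a stride of 2 along both the row and the column direction</strong>.</p>
<p>The selected elements are then concatenated into a single tensor and flattened. <strong>The channel dimension is now 4 times the original</strong> (because H and W are each halved), and a <strong>fully connected layer then reduces it to twice the original</strong>.</p>
<p>The figure below illustrates this (input tensor with N=1, H=W=8, C=1; the final fully connected layer is not shown).</p>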
<p><img alt="Patch_Merging" src="../_images/Patch_Merging.png" /></p>
<p><img alt="Patch_Merging_dim" src="../_images/Patch_Merging_dim.png" /></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">PatchMerging</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_resolution</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">norm_layer</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">input_resolution</span> <span class="o">=</span> <span class="n">input_resolution</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dim</span> <span class="o">=</span> <span class="n">dim</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">dim</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">dim</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">norm_layer</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">dim</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        x: B, H*W, C</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="n">H</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_resolution</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">L</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
        <span class="k">assert</span> <span class="n">L</span> <span class="o">==</span> <span class="n">H</span> <span class="o">*</span> <span class="n">W</span><span class="p">,</span> <span class="s2">&quot;input feature has wrong size&quot;</span>
        <span class="k">assert</span> <span class="n">H</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">W</span> <span class="o">%</span> <span class="mi">2</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;x size (</span><span class="si">{</span><span class="n">H</span><span class="si">}</span><span class="s2">*</span><span class="si">{</span><span class="n">W</span><span class="si">}</span><span class="s2">) are not even.&quot;</span>

        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>

        <span class="n">x0</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">:]</span>  <span class="c1"># B H/2 W/2 C</span>
        <span class="n">x1</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">:]</span>  <span class="c1"># B H/2 W/2 C</span>
        <span class="n">x2</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">:]</span>  <span class="c1"># B H/2 W/2 C</span>
        <span class="n">x3</span> <span class="o">=</span> <span class="n">x</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">::</span><span class="mi">2</span><span class="p">,</span> <span class="p">:]</span>  <span class="c1"># B H/2 W/2 C</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">([</span><span class="n">x0</span><span class="p">,</span> <span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">x3</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># B H/2 W/2 4*C</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span> <span class="o">*</span> <span class="n">C</span><span class="p">)</span>  <span class="c1"># B H/2*W/2 4*C</span>

        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduction</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
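<p>The index pattern of the four slices can be checked by hand on a small grid. The sketch below reproduces only the gather and concatenation on nested lists, assuming a single image and leaving out the LayerNorm and the linear reduction; <code class="docutils literal notranslate"><span class="pre">merge_patches</span></code> is a hypothetical name.</p>

```python
def merge_patches(grid):
    """Apply the 2 x 2 gather of Patch Merging to an H x W x C nested list."""
    height, width = len(grid), len(grid[0])
    merged = []
    for r in range(0, height, 2):
        row = []
        for c in range(0, width, 2):
            # Same order as torch.cat([x0, x1, x2, x3], -1) in the class above:
            # (even row, even col), (odd row, even col), (even row, odd col), (odd row, odd col)
            row.append(grid[r][c] + grid[r + 1][c] + grid[r][c + 1] + grid[r + 1][c + 1])
        merged.append(row)
    return merged


# 4 x 4 grid with one channel; values 0..15 in row-major order.
grid = [[[4 * r + c] for c in range(4)] for r in range(4)]
merged = merge_patches(grid)
print(merged[0][0])
```

<p>The 4 × 4 × 1 input becomes a 2 × 2 × 4 output: the spatial size is halved in each direction and the channel count is multiplied by 4, before the linear layer brings it back down to 2.</p>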
</section>
<section id="window-partition-reverse">
<h2>Window Partition/Reverse<a class="headerlink" href="#window-partition-reverse" title="Permalink to this heading">#</a></h2>
<p>The <code class="docutils literal notranslate"><span class="pre">window</span> <span class="pre">partition</span></code> function splits a tensor into windows of a given size. It reshapes a tensor of shape <code class="docutils literal notranslate"><span class="pre">N</span> <span class="pre">H</span> <span class="pre">W</span> <span class="pre">C</span></code> into <code class="docutils literal notranslate"><span class="pre">num_windows*B,</span> <span class="pre">window_size,</span> <span class="pre">window_size,</span> <span class="pre">C</span></code> (N and B both denote the batch size), where <code class="docutils literal notranslate"><span class="pre">num_windows</span> <span class="pre">=</span> <span class="pre">H*W</span> <span class="pre">/</span> <span class="pre">(window_size*window_size)</span></code> is the number of windows per image. The <code class="docutils literal notranslate"><span class="pre">window</span> <span class="pre">reverse</span></code> function is the corresponding inverse. Both functions are used later in <code class="docutils literal notranslate"><span class="pre">Window</span> <span class="pre">Attention</span></code>.</p>
<p><img alt="Window_Partition_Reverse" src="../_images/Window_Partition_Reverse.png" /></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">window_partition</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">window_size</span><span class="p">):</span>
    <span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">H</span> <span class="o">//</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">W</span> <span class="o">//</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>
    <span class="n">windows</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">windows</span>

<span class="k">def</span> <span class="nf">window_reverse</span><span class="p">(</span><span class="n">windows</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">):</span>
    <span class="n">B</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">windows</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="p">(</span><span class="n">H</span> <span class="o">*</span> <span class="n">W</span> <span class="o">/</span> <span class="n">window_size</span> <span class="o">/</span> <span class="n">window_size</span><span class="p">))</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">windows</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">H</span> <span class="o">//</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">W</span> <span class="o">//</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
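<p>The effect of the <code class="docutils literal notranslate"><span class="pre">view</span></code> and <code class="docutils literal notranslate"><span class="pre">permute</span></code> calls is easier to follow on a tiny example. The following pure-Python sketch assumes a single image with the channel axis dropped; <code class="docutils literal notranslate"><span class="pre">partition_windows</span></code> and <code class="docutils literal notranslate"><span class="pre">reverse_windows</span></code> are hypothetical stand-ins for the tensor functions above.</p>

```python
def partition_windows(grid, window_size):
    """Split an H x W nested list into a list of window_size x window_size windows."""
    height, width = len(grid), len(grid[0])
    windows = []
    for top in range(0, height, window_size):
        for left in range(0, width, window_size):
            windows.append([row[left:left + window_size]
                            for row in grid[top:top + window_size]])
    return windows


def reverse_windows(windows, window_size, height, width):
    """Inverse of partition_windows: stitch the windows back into an H x W grid."""
    per_row = width // window_size  # number of windows in each row of windows
    grid = [[None] * width for _ in range(height)]
    for index, window in enumerate(windows):
        top = (index // per_row) * window_size
        left = (index % per_row) * window_size
        for dr, row in enumerate(window):
            for dc, value in enumerate(row):
                grid[top + dr][left + dc] = value
    return grid


grid = [[4 * r + c for c in range(4)] for r in range(4)]
windows = partition_windows(grid, window_size=2)
restored = reverse_windows(windows, window_size=2, height=4, width=4)
print(len(windows), windows[0])
```

<p>A 4 × 4 grid with a window size of 2 gives 4 × 4 / (2 × 2) = 4 windows, listed in row-major order, and reversing the partition restores the original grid.</p>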
</section>
<section id="window-attention">
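<h2>Window Attention<a class="headerlink" href="#window-attention" title="Permalink to this heading">#</a></h2>
<p>A conventional Transformer <strong>computes attention globally</strong>, so its cost grows quadratically with the number of tokens. Swin Transformer instead <strong>restricts the attention computation to each window</strong>, which reduces the amount of computation. Let us first look at the formula:</p>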
<div class="math notranslate nohighlight">
\[
\mathrm{Attention}(Q,K,V)=\mathrm{Softmax}\left(\frac{QK^T}{\sqrt{d}}+B\right)V
\]</div>
<p>The main difference from the original attention formula is that a <strong>relative position encoding</strong> B is added to the scaled product of Q and K.</p>
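<p>The formula can be worked through for a single head inside one window with plain Python lists. This is an illustrative sketch only: <code class="docutils literal notranslate"><span class="pre">attention_with_bias</span></code> is a hypothetical helper, and the bias matrix is passed in directly rather than looked up from a learned table as the real module does.</p>

```python
import math


def softmax(row):
    peak = max(row)  # subtract the maximum for numerical stability
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_with_bias(q, k, v, bias):
    """Softmax(Q K^T / sqrt(d) + B) V for one head inside one window.

    q, k, v: lists of n token vectors; bias: n x n nested list (the B term).
    """
    d = len(q[0])
    out = []
    for i, q_i in enumerate(q):
        scores = [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d) + bias[i][j]
                  for j, k_j in enumerate(k)]
        weights = softmax(scores)
        out.append([sum(w * v_j[c] for w, v_j in zip(weights, v))
                    for c in range(len(v[0]))])
    return out


zeros = [[0.0, 0.0], [0.0, 0.0]]
v = [[1.0, 0.0], [0.0, 1.0]]
uniform = attention_with_bias(zeros, zeros, v, bias=zeros)
biased = attention_with_bias(zeros, zeros, v, bias=[[0.0, -1e9], [0.0, 0.0]])
print(uniform[0], biased[0])
```

<p>With identical scores and a zero bias every token averages all values; a strongly negative bias entry removes the corresponding key from that average. The bias therefore shifts attention purely as a function of position, independently of the content of Q and K.</p>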
<img src="./figures/Swin-T_block.png" alt="Swin-T_block" style="zoom:50%;" />
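<p>The saving from restricting attention to windows can be estimated with the two cost formulas given in the Swin Transformer paper (softmax cost excluded). The sketch below evaluates them for a 56 × 56 feature map with C = 96 and a window size of M = 7, values assumed here to match the first stage discussed on this page.</p>

```python
def msa_flops(h, w, C):
    """Global multi-head self-attention: 4*h*w*C^2 + 2*(h*w)^2*C."""
    return 4 * h * w * C ** 2 + 2 * (h * w) ** 2 * C


def wmsa_flops(h, w, C, M):
    """Window attention with M x M windows: 4*h*w*C^2 + 2*M^2*h*w*C."""
    return 4 * h * w * C ** 2 + 2 * M ** 2 * h * w * C


h = w = 56
C = 96
M = 7
global_cost = msa_flops(h, w, C)
window_cost = wmsa_flops(h, w, C, M)
print(global_cost, window_cost, global_cost / window_cost)
```

<p>The first term (the linear projections) is the same in both cases. Only the second term changes: it is quadratic in the number of tokens h × w for global attention, but linear in h × w for a fixed window size M.</p>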
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">WindowAttention</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="sa">r</span><span class="sd">&quot;&quot;&quot; Window based multi-head self attention (W-MSA) module with relative position bias.</span>
<span class="sd">    It supports both of shifted and non-shifted window.</span>

<span class="sd">    Args:</span>
<span class="sd">        dim (int): Number of input channels.</span>
<span class="sd">        window_size (tuple[int]): The height and width of the window.</span>
<span class="sd">        num_heads (int): Number of attention heads.</span>
<span class="sd">        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True</span>
<span class="sd">        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set</span>
<span class="sd">        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0</span>
<span class="sd">        proj_drop (float, optional): Dropout ratio of output. Default: 0.0</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">attn_drop</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span> <span class="n">proj_drop</span><span class="o">=</span><span class="mf">0.</span><span class="p">):</span>

        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dim</span> <span class="o">=</span> <span class="n">dim</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span> <span class="o">=</span> <span class="n">window_size</span>  <span class="c1"># Wh, Ww</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="c1"># nH</span>
        <span class="n">head_dim</span> <span class="o">=</span> <span class="n">dim</span> <span class="o">//</span> <span class="n">num_heads</span> <span class="c1"># number of channels per attention head</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="n">qk_scale</span> <span class="ow">or</span> <span class="n">head_dim</span> <span class="o">**</span> <span class="o">-</span><span class="mf">0.5</span>

        <span class="c1"># define a parameter table of relative position bias</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">relative_position_bias_table</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Parameter</span><span class="p">(</span>
            <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">2</span> <span class="o">*</span> <span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="n">num_heads</span><span class="p">))</span>  <span class="c1"># learnable table of shape ((2*Wh-1) * (2*Ww-1), nH), indexed later to build the relative position bias</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="n">qkv_bias</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">attn_drop</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">attn_drop</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">dim</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">proj_drop</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">proj_drop</span><span class="p">)</span>

        <span class="n">trunc_normal_</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">relative_position_bias_table</span><span class="p">,</span> <span class="n">std</span><span class="o">=</span><span class="mf">.02</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="c1"># relative position encoding ...</span>
</pre></div>
</div>
<section id="id3">
<h3>An intuitive view of relative position encoding<a class="headerlink" href="#id3" title="Permalink to this heading">#</a></h3>
<blockquote>
<div><p>Q,K,V.shape=[numWindows*B, num_heads, window_size*window_size, head_dim]</p>
<ul class="simple">
<li><p>window_size*window_size corresponds to the number of <code class="docutils literal notranslate"><span class="pre">token</span></code>s in NLP</p></li>
<li><p><span class="math notranslate nohighlight">\(head\_dim=\frac{Embedding\_dim}{num\_heads}\)</span> corresponds to the dimension of a <code class="docutils literal notranslate"><span class="pre">token</span></code>'s word-embedding vector in NLP</p></li>
</ul>
<p>The <code class="docutils literal notranslate"><span class="pre">Attention</span></code> tensor computed by <span class="math notranslate nohighlight">\({QK}^T\)</span> has shape <code class="docutils literal notranslate"><span class="pre">[numWindows*B,</span> <span class="pre">num_heads,</span> <span class="pre">Q_tokens,</span> <span class="pre">K_tokens]</span></code></p>
<ul class="simple">
<li><p>where Q_tokens=K_tokens=window_size*window_size</p></li>
</ul>
</div></blockquote>
<p>Take <code class="docutils literal notranslate"><span class="pre">window_size=2</span></code> as an example:</p>
<img src="./figures/%E7%BB%9D%E5%AF%B9%E4%BD%8D%E7%BD%AE%E7%B4%A2%E5%BC%95.png" alt="Absolute position index" style="zoom:33%;" />
<p>Therefore: <span class="math notranslate nohighlight">\({QK}^T=\left[\begin{array}{cccc}a_{11} &amp; a_{12} &amp; a_{13} &amp; a_{14} \\ a_{21} &amp; a_{22} &amp; a_{23} &amp; a_{24} \\ a_{31} &amp; a_{32} &amp; a_{33} &amp; a_{34} \\ a_{41} &amp; a_{42} &amp; a_{43} &amp; a_{44}\end{array}\right]\)</span></p>
<ul class="simple">
<li><p><strong>Row <span class="math notranslate nohighlight">\(i\)</span> holds the attention of the <code class="docutils literal notranslate"><span class="pre">query</span></code> of token <span class="math notranslate nohighlight">\(i\)</span> over the <code class="docutils literal notranslate"><span class="pre">key</span></code>s of all tokens.</strong></p></li>
<li><p>For the Attention tensor, <strong>the coordinates of the other elements change depending on which element is taken as the origin</strong>.</p></li>
</ul>
<img src="./figures/%E7%9B%B8%E5%AF%B9%E4%BD%8D%E7%BD%AE%E7%B4%A2%E5%BC%95.png" alt="Relative position index" style="zoom:50%;" />
<p>So the relative position index of <span class="math notranslate nohighlight">\({QK}^T\)</span> is <span class="math notranslate nohighlight">\(\left[\begin{array}{cccc}(0,0) &amp; (0,-1) &amp; (-1,0) &amp; (-1,-1) \\ (0,1) &amp; (0,0) &amp; (-1,1) &amp; (-1,0) \\ (1,0) &amp; (1,-1) &amp; (0,0) &amp; (0,-1) \\ (1,1) &amp; (1,0) &amp; (0,1) &amp; (0,0)\end{array}\right]\)</span></p>
<p>Ultimately we want to replace the two-dimensional position coordinate <code class="docutils literal notranslate"><span class="pre">(x,y)</span></code> with a one-dimensional one. Simply using <code class="docutils literal notranslate"><span class="pre">x+y</span></code> is ambiguous: the coordinates (1,2) and (2,1) would both become 3. To avoid this, a <strong>linear transformation</strong> is later applied to the relative position index so that each <strong>one-dimensional</strong> coordinate <strong>maps uniquely</strong> to one <strong>two-dimensional</strong> coordinate. The code section shows the details.</p>
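<p>The collision described above can be checked with a few lines of plain Python. This is only an illustrative sketch, not code from the Swin Transformer repository: the window size <code class="docutils literal notranslate"><span class="pre">M=3</span></code> and the names <code class="docutils literal notranslate"><span class="pre">naive</span></code> and <code class="docutils literal notranslate"><span class="pre">linear</span></code> are assumptions chosen so that the offsets (1,2) and (2,1) exist.</p>

```python
# Relative offsets inside an M x M window: each axis ranges over -(M-1)..(M-1).
M = 3  # assumed window size, chosen so that offsets (1, 2) and (2, 1) exist
offsets = [(dy, dx) for dy in range(-(M - 1), M) for dx in range(-(M - 1), M)]

# Naive flattening: shift both axes to start at 0, then add them.
naive = {}
for dy, dx in offsets:
    naive.setdefault((dy + M - 1) + (dx + M - 1), []).append((dy, dx))

# Linear flattening: scale the row offset by the number of possible column offsets.
linear = {}
for dy, dx in offsets:
    linear.setdefault((dy + M - 1) * (2 * M - 1) + (dx + M - 1), []).append((dy, dx))

print(len(offsets), len(naive), len(linear))
print(sorted(naive[7]))
```

<p>With the naive sum, 25 distinct offsets share only 9 codes, whereas the scaled version gives every offset its own code in the range 0 to (2M-1)*(2M-1)-1.</p>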
</section>
<section id="id4">
<h3>Relative position encoding: code walkthrough<a class="headerlink" href="#id4" title="Permalink to this heading">#</a></h3>
<p>First we use <code class="docutils literal notranslate"><span class="pre">torch.arange</span></code> and <code class="docutils literal notranslate"><span class="pre">torch.meshgrid</span></code> to generate the coordinates. Here we take <code class="docutils literal notranslate"><span class="pre">window_size=2</span></code> as the example.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">coords_h</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="n">coords_w</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">coords</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">([</span><span class="n">coords_h</span><span class="p">,</span> <span class="n">coords_w</span><span class="p">])</span> <span class="c1"># -&gt; 2*(wh, ww)</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd">  (tensor([[0, 0],</span>
<span class="sd">           [1, 1]]), </span>
<span class="sd">   tensor([[0, 1],</span>
<span class="sd">           [0, 1]]))</span>
<span class="sd">&quot;&quot;&quot;</span>
</pre></div>
</div>
<p>Then we stack the two grids and flatten them into a two-dimensional tensor.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">coords</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">coords</span><span class="p">)</span>  <span class="c1"># 2, Wh, Ww</span>
<span class="n">coords_flatten</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">coords</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>  <span class="c1"># 2, Wh*Ww</span>
<span class="sd">&quot;&quot;&quot;</span>
<span class="sd">tensor([[0, 0, 1, 1],</span>
<span class="sd">        [0, 1, 0, 1]])</span>
<span class="sd">&quot;&quot;&quot;</span>
</pre></div>
</div>
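<p>Using broadcasting, we insert a new axis at the end of one copy and in the middle of the other, then subtract them to obtain a tensor of shape <code class="docutils literal notranslate"><span class="pre">2,</span> <span class="pre">wh*ww,</span> <span class="pre">wh*ww</span></code>.</p>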
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">relative_coords_first</span> <span class="o">=</span> <span class="n">coords_flatten</span><span class="p">[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">]</span>  <span class="c1"># 2, wh*ww, 1</span>
<span class="n">relative_coords_second</span> <span class="o">=</span> <span class="n">coords_flatten</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="c1"># 2, 1, wh*ww</span>
<span class="n">relative_coords</span> <span class="o">=</span> <span class="n">relative_coords_first</span> <span class="o">-</span> <span class="n">relative_coords_second</span> <span class="c1"># resulting tensor has shape 2, wh*ww, wh*ww</span>
</pre></div>
</div>
<p><img alt="relative_pos_code" src="../_images/relative_pos_code.png" /></p>
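<p>The broadcast subtraction can be reproduced without PyTorch to confirm that it yields the relative position matrix shown earlier. The sketch below is a plain-Python stand-in, not the original code: the names <code class="docutils literal notranslate"><span class="pre">positions</span></code> and <code class="docutils literal notranslate"><span class="pre">relative</span></code> are hypothetical, and the result is laid out as <code class="docutils literal notranslate"><span class="pre">[query][key]</span></code> pairs rather than as a tensor of shape <code class="docutils literal notranslate"><span class="pre">2,</span> <span class="pre">wh*ww,</span> <span class="pre">wh*ww</span></code>.</p>

```python
window_size = 2

# Token positions in row-major order, as produced by meshgrid + flatten: (row, col).
positions = [(r, c) for r in range(window_size) for c in range(window_size)]

# Entry [i][j] is the position of query token i minus the position of key token j.
relative = [[(q[0] - k[0], q[1] - k[1]) for k in positions] for q in positions]

for row in relative:
    print(row)
```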
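<p>Because the coordinates come from a subtraction, the resulting indices start from negative values. <strong>We add an offset so that they start from 0</strong>.</p>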
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">relative_coords</span> <span class="o">=</span> <span class="n">relative_coords</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span> <span class="c1"># Wh*Ww, Wh*Ww, 2</span>
<span class="n">relative_coords</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>
<span class="n">relative_coords</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>
</pre></div>
</div>
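<p>Next we need to flatten these into a one-dimensional offset. The coordinates (1,2) and (2,1) are different in two dimensions, <strong>but if we convert them to a one-dimensional offset by adding the x and y coordinates, they produce the same offset</strong>.</p>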
<p><img alt="bias0" src="../_images/bias0.png" /></p>
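<p>So we multiply the row coordinate by <code class="docutils literal notranslate"><span class="pre">2*Ww-1</span></code> to tell them apart.</p>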
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">relative_coords</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*=</span> <span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span>
</pre></div>
</div>
<p><img alt="offset multiply" src="../_images/coords.png" /></p>
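<p>Then we sum over the last dimension to obtain a one-dimensional coordinate, and register the result as a buffer that does not take part in training.</p>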
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">relative_position_index</span> <span class="o">=</span> <span class="n">relative_coords</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># Wh*Ww, Wh*Ww</span>
<span class="bp">self</span><span class="o">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s2">&quot;relative_position_index&quot;</span><span class="p">,</span> <span class="n">relative_position_index</span><span class="p">)</span>
</pre></div>
</div>
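<p>Putting the three steps together (shift to start at 0, scale the row offset, sum) gives the final index table. The following is a minimal plain-Python sketch of the same arithmetic, assuming a hypothetical helper name <code class="docutils literal notranslate"><span class="pre">relative_position_index</span></code>; it is meant for checking small cases by hand, not as a replacement for the tensor code.</p>

```python
def relative_position_index(wh, ww):
    """Return the (wh*ww) x (wh*ww) table of flattened relative position indices."""
    positions = [(r, c) for r in range(wh) for c in range(ww)]
    index = []
    for qr, qc in positions:
        row = []
        for kr, kc in positions:
            dr = qr - kr + (wh - 1)  # shift the row offset so it starts at 0
            dc = qc - kc + (ww - 1)  # shift the column offset so it starts at 0
            row.append(dr * (2 * ww - 1) + dc)  # scale rows, then sum
        index.append(row)
    return index


index = relative_position_index(2, 2)
for row in index:
    print(row)
```

<p>For <code class="docutils literal notranslate"><span class="pre">window_size=2</span></code> every entry lies between 0 and 8, so it can be used directly as a row number into a table with (2*2-1)*(2*2-1)=9 rows.</p>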
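<p>What we have computed so far is the relative position index, not the relative position bias itself. The trainable parameters <span class="math notranslate nohighlight">\(\hat B\)</span> are stored in the <code class="docutils literal notranslate"><span class="pre">relative</span> <span class="pre">position</span> <span class="pre">bias</span> <span class="pre">table</span></code>, whose length is <strong>(2M−1) × (2M−1)</strong> (each axis of the relative offset takes 2M−1 values, which is also why the linear transformation multiplies by 2M−1). The relative position bias B in the formula above is obtained by looking up the <code class="docutils literal notranslate"><span class="pre">relative</span> <span class="pre">position</span> <span class="pre">bias</span> <span class="pre">table</span></code> with the relative position index computed above.</p>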
<img src="./figures/relative_pos_bias_table.png" alt="relative_pos_bias_table" style="zoom:50%;" />
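<p>The lookup itself is a plain gather. The sketch below illustrates it in plain Python for <code class="docutils literal notranslate"><span class="pre">window_size=2</span></code>; the table values and the choice of two heads are made up for illustration (in the real module the table is a trained parameter).</p>

```python
# Flattened relative position index for window_size=2 (values in 0..8).
index = [[4, 3, 1, 0],
         [5, 4, 2, 1],
         [7, 6, 4, 3],
         [8, 7, 5, 4]]

num_heads = 2  # assumed value for the illustration
# Hypothetical table values: row t, head h holds t + 0.1 * h so lookups are easy to trace.
table = [[t + 0.1 * h for h in range(num_heads)] for t in range(9)]  # shape (9, num_heads)

# Gather: bias[h][i][j] = table[index[i][j]][h], giving shape (num_heads, 4, 4).
bias = [[[table[index[i][j]][h] for j in range(4)] for i in range(4)]
        for h in range(num_heads)]

for row in bias[0]:
    print(row)
```

<p>Token pairs with the same relative offset, such as every diagonal entry, read the same table row, so 16 query-key pairs share only 9 parameters per head.</p>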
<p>Next, let's look at the forward pass.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Args:</span>
<span class="sd">        x: input features with shape of (num_windows*B, N, C)</span>
<span class="sd">        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">B_</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span>

    <span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B_</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">C</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">)</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span>
    <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">qkv</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>  <span class="c1"># make torchscript happy (cannot use tensor as tuple)</span>

    <span class="n">q</span> <span class="o">=</span> <span class="n">q</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span>
    <span class="n">attn</span> <span class="o">=</span> <span class="p">(</span><span class="n">q</span> <span class="o">@</span> <span class="n">k</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span>

    <span class="n">relative_position_bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">relative_position_bias_table</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">relative_position_index</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># Wh*Ww,Wh*Ww,nH</span>
    <span class="n">relative_position_bias</span> <span class="o">=</span> <span class="n">relative_position_bias</span><span class="o">.</span><span class="n">permute</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">contiguous</span><span class="p">()</span>  <span class="c1"># nH, Wh*Ww, Wh*Ww</span>
    <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span> <span class="o">+</span> <span class="n">relative_position_bias</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># bias has shape (1, nH, Wh*Ww, Wh*Ww) and broadcasts over the windows</span>

    <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># discussed later in this article</span>
        <span class="o">...</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>

    <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_drop</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>

    <span class="n">x</span> <span class="o">=</span> <span class="p">(</span><span class="n">attn</span> <span class="o">@</span> <span class="n">v</span><span class="p">)</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">B_</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj_drop</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
<ul class="simple">
<li><p>The input tensor has shape <code class="docutils literal notranslate"><span class="pre">[numWindows*B,</span> <span class="pre">window_size</span> <span class="pre">*</span> <span class="pre">window_size,</span> <span class="pre">C]</span></code>.</p></li>
<li><p>It passes through the fully connected layer <code class="docutils literal notranslate"><span class="pre">self.qkv</span></code>, is reshaped, and has its axes permuted, giving a tensor of shape <code class="docutils literal notranslate"><span class="pre">[3,</span> <span class="pre">numWindows*B,</span> <span class="pre">num_heads,</span> <span class="pre">window_size*window_size,</span> <span class="pre">C//num_heads]</span></code> that is split into <code class="docutils literal notranslate"><span class="pre">q,k,v</span></code>.</p></li>
<li><p>Following the formula, we multiply <code class="docutils literal notranslate"><span class="pre">q</span></code> by the <code class="docutils literal notranslate"><span class="pre">scale</span></code> factor and then matrix-multiply it with <code class="docutils literal notranslate"><span class="pre">k</span></code> (whose last two dimensions are swapped so that the matrix product is defined). The result is the <code class="docutils literal notranslate"><span class="pre">attn</span></code> tensor of shape <code class="docutils literal notranslate"><span class="pre">[numWindows*B,</span> <span class="pre">num_heads,</span> <span class="pre">window_size*window_size,</span> <span class="pre">window_size*window_size]</span></code>.</p></li>
<li><p>Earlier we created a learnable variable of shape <code class="docutils literal notranslate"><span class="pre">((2*window_size-1)*(2*window_size-1),</span> <span class="pre">numHeads)</span></code> for the position encoding. We index it with the flattened relative position index <code class="docutils literal notranslate"><span class="pre">self.relative_position_index.view(-1)</span></code> and reshape the result into a bias of shape <code class="docutils literal notranslate"><span class="pre">(window_size*window_size,</span> <span class="pre">window_size*window_size,</span> <span class="pre">numHeads)</span></code>, then apply permute(2,0,1) and add it to the <code class="docutils literal notranslate"><span class="pre">attn</span></code> tensor.</p></li>
<li><p>Setting the mask aside for now, the rest is the same as in a standard transformer: softmax, dropout, matrix multiplication with <code class="docutils literal notranslate"><span class="pre">V</span></code>, then one fully connected layer and dropout.</p></li>
</ul>
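<p>The shapes listed above can be traced with a small helper. This is a bookkeeping sketch only: the function name <code class="docutils literal notranslate"><span class="pre">window_attention_shapes</span></code> is hypothetical, and the example values (64 windows, batch size 2, window size 7, 96 channels, 3 heads) are assumptions picked for illustration.</p>

```python
def window_attention_shapes(num_windows, batch, window_size, dim, num_heads):
    """Trace tensor shapes through the window attention forward pass."""
    b_ = num_windows * batch          # B_ in the forward code
    n = window_size * window_size     # N, the number of tokens per window
    head_dim = dim // num_heads
    return {
        "x": (b_, n, dim),
        "qkv": (3, b_, num_heads, n, head_dim),   # after reshape + permute
        "q": (b_, num_heads, n, head_dim),        # same for k and v
        "attn": (b_, num_heads, n, n),            # q @ k^T
        "bias_table": ((2 * window_size - 1) ** 2, num_heads),
        "bias": (1, num_heads, n, n),             # after permute + unsqueeze(0)
        "out": (b_, n, dim),                      # after transpose + reshape + proj
    }


shapes = window_attention_shapes(num_windows=64, batch=2, window_size=7, dim=96, num_heads=3)
for name, shape in shapes.items():
    print(name, shape)
```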
</section>
</section>
<section id="shifted-window-attention">
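<h2>Shifted Window Attention<a class="headerlink" href="#shifted-window-attention" title="Permalink to this heading">#</a></h2>
<p>The Window Attention described above computes attention within each window separately. To let information flow between windows, Swin Transformer also introduces the shifted window operation.</p>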
<img src="./figures/Shifted_Window.png" alt="Shifted_Window" style="zoom:67%;" />
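<p>On the left is the non-overlapping Window Attention; on the right is the Shifted Window Attention, in which the windows have been shifted. The shifted windows contain elements from originally adjacent windows. This introduces a new problem, however: <strong>the number of windows increases</strong>, from the original 4 windows to 9. In the actual code, this is <strong>implemented indirectly by shifting the feature map and applying a mask to the Attention</strong>, which gives an equivalent result <strong>while keeping the original number of windows</strong>.</p>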
<p align=center><img src="./figures/W-MSA.png" alt="W-MSA" style="zoom:50%;" /></p>
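<p>The jump from 4 to 9 windows follows from counting how many pieces each axis is cut into once the window borders are displaced. The sketch below is an assumption-labelled illustration (the helper <code class="docutils literal notranslate"><span class="pre">count_regions</span></code> is hypothetical and not part of the Swin code); it assumes the shifted borders sit at <code class="docutils literal notranslate"><span class="pre">shift,</span> <span class="pre">shift+window_size,</span> <span class="pre">...</span></code> along each axis.</p>

```python
def count_regions(size, window_size, shift):
    """Count the segments one axis is cut into when window borders are shifted."""
    # Window borders that fall strictly inside the axis (position 0 is not a cut).
    cuts = {c for c in range(shift % window_size, size, window_size) if c != 0}
    return len(cuts) + 1


regular = count_regions(4, 2, 0) ** 2   # 4x4 feature map, window_size=2, no shift
shifted = count_regions(4, 2, 1) ** 2   # same map with shift_size=1
print(regular, shifted)
```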
<section id="id5">
<h3>Shifting the feature map<a class="headerlink" href="#id5" title="Permalink to this heading">#</a></h3>
<p>In the code, the feature map is shifted with <code class="docutils literal notranslate"><span class="pre">torch.roll</span></code>. The diagram below illustrates it.</p>
<img src="./figures/torch_roll.png" alt="torch_roll" style="zoom:67%;" />
<blockquote>
<div><p>For the <code class="docutils literal notranslate"><span class="pre">reverse</span> <span class="pre">cyclic</span> <span class="pre">shift</span></code>, simply set the <code class="docutils literal notranslate"><span class="pre">shifts</span></code> argument to the corresponding positive value.</p>
</div></blockquote>
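<p>To make the cyclic shift and its reverse concrete, here is a plain-Python sketch of a two-dimensional roll. The helper <code class="docutils literal notranslate"><span class="pre">roll2d</span></code> is hypothetical; it is written to follow the same convention as <code class="docutils literal notranslate"><span class="pre">torch.roll</span></code> (a positive shift moves content toward higher indices), and the 4x4 grid and shift of 1 are assumed example values.</p>

```python
def roll2d(grid, shift_h, shift_w):
    """Cyclically shift a 2-D list; positive shifts move rows down and columns right."""
    h, w = len(grid), len(grid[0])
    return [[grid[(r - shift_h) % h][(c - shift_w) % w] for c in range(w)]
            for r in range(h)]


grid = [[r * 4 + c for c in range(4)] for r in range(4)]
shifted = roll2d(grid, -1, -1)    # cyclic shift: content moves up and to the left
restored = roll2d(shifted, 1, 1)  # reverse cyclic shift: same magnitude, positive sign

for row in shifted:
    print(row)
```

<p>The first row and first column of the original grid wrap around to the bottom and right edges, and rolling back by the positive amount recovers the original layout.</p>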
</section>
<section id="attention-mask">
<h3>Attention Mask<a class="headerlink" href="#attention-mask" title="Permalink to this heading">#</a></h3>
<p>This is the essence of Swin Transformer: with a suitable mask, <code class="docutils literal notranslate"><span class="pre">Shifted</span> <span class="pre">Window</span> <span class="pre">Attention</span></code> produces an equivalent result using the same number of windows as <code class="docutils literal notranslate"><span class="pre">Window</span> <span class="pre">Attention</span></code>.</p>
<p>First we assign an index to every window produced by the shift, and then apply a <code class="docutils literal notranslate"><span class="pre">roll</span></code> operation (window_size=2, shift_size=1).</p>
<img src="./figures/Shift_window_index.png" alt="Shift window index" style="zoom:67%;" />
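<p>When computing Attention, we want <strong>Q and K with the same index to attend to each other, and the results between Q and K with different indices to be ignored</strong>. The correct result is shown in the figure below.</p>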
<img src="./figures/Mask.png" alt="Mask" style="zoom:50%;" />
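<p>To obtain the correct result with the original four windows, we must add a mask to the Attention result (shown on the far right of the figure above). The relevant code is:</p>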
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shift_size</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
    <span class="c1"># calculate attention mask for SW-MSA</span>
    <span class="n">H</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_resolution</span>
    <span class="n">img_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>  <span class="c1"># 1 H W 1</span>
    <span class="n">h_slices</span> <span class="o">=</span> <span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">),</span>
                <span class="nb">slice</span><span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">,</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">shift_size</span><span class="p">),</span>
                <span class="nb">slice</span><span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">shift_size</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span>
    <span class="n">w_slices</span> <span class="o">=</span> <span class="p">(</span><span class="nb">slice</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">),</span>
                <span class="nb">slice</span><span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">,</span> <span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">shift_size</span><span class="p">),</span>
                <span class="nb">slice</span><span class="p">(</span><span class="o">-</span><span class="bp">self</span><span class="o">.</span><span class="n">shift_size</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span>
    <span class="n">cnt</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">h_slices</span><span class="p">:</span>
        <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">w_slices</span><span class="p">:</span>
            <span class="n">img_mask</span><span class="p">[:,</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">cnt</span>
            <span class="n">cnt</span> <span class="o">+=</span> <span class="mi">1</span>

    <span class="n">mask_windows</span> <span class="o">=</span> <span class="n">window_partition</span><span class="p">(</span><span class="n">img_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">)</span>  <span class="c1"># nW, window_size, window_size, 1</span>
    <span class="n">mask_windows</span> <span class="o">=</span> <span class="n">mask_windows</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">)</span>
    <span class="n">attn_mask</span> <span class="o">=</span> <span class="n">mask_windows</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="o">-</span> <span class="n">mask_windows</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span>
    <span class="n">attn_mask</span> <span class="o">=</span> <span class="n">attn_mask</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">attn_mask</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="o">-</span><span class="mf">100.0</span><span class="p">))</span><span class="o">.</span><span class="n">masked_fill</span><span class="p">(</span><span class="n">attn_mask</span> <span class="o">==</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="mf">0.0</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
    <span class="n">attn_mask</span> <span class="o">=</span> <span class="kc">None</span>
</pre></div>
</div>
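<p>The same construction can be followed step by step in plain Python for a small feature map. This is a sketch under stated assumptions, not the original implementation: <code class="docutils literal notranslate"><span class="pre">build_attn_mask</span></code> is a hypothetical helper, it works on nested lists instead of tensors, and it inlines a simple row-major window partition in place of <code class="docutils literal notranslate"><span class="pre">window_partition</span></code>.</p>

```python
def build_attn_mask(H, W, window_size, shift_size, neg=-100.0):
    """Build one (window_size^2 x window_size^2) mask per window of the rolled map."""
    def slices(n):
        # Same three bands as h_slices / w_slices, written with explicit bounds.
        return (slice(0, n - window_size),
                slice(n - window_size, n - shift_size),
                slice(n - shift_size, n))

    # Label every position of the rolled feature map with a region id (0..8).
    img = [[0] * W for _ in range(H)]
    cnt = 0
    for hs in slices(H):
        for ws in slices(W):
            for r in range(H)[hs]:
                for c in range(W)[ws]:
                    img[r][c] = cnt
            cnt += 1

    # Partition into non-overlapping windows, flattening each one row-major.
    windows = [[img[r0 + i][c0 + j] for i in range(window_size) for j in range(window_size)]
               for r0 in range(0, H, window_size) for c0 in range(0, W, window_size)]

    # Token pairs from different regions get a large negative value, others 0.
    return [[[0.0 if a == b else neg for b in win] for a in win] for win in windows]


mask = build_attn_mask(H=4, W=4, window_size=2, shift_size=1)
for window_mask in mask:
    for row in window_mask:
        print(row)
    print()
```

<p>Adding -100 before the softmax drives the corresponding attention weights close to zero, so tokens that ended up in the same window only because of the roll effectively do not attend to each other.</p>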
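<p>With the settings in the figure above, the <code class="docutils literal notranslate"><span class="pre">attn_mask</span></code> computation above produces the following mask:</p>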
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">tensor</span><span class="p">([[[[[</span>   <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">],</span>
           <span class="p">[</span>   <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">],</span>
           <span class="p">[</span>   <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">],</span>
           <span class="p">[</span>   <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">]]],</span>


         <span class="p">[[[</span>   <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">],</span>
           <span class="p">[</span><span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">],</span>
           <span class="p">[</span>   <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">],</span>
           <span class="p">[</span><span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">]]],</span>


         <span class="p">[[[</span>   <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">],</span>
           <span class="p">[</span>   <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">],</span>
           <span class="p">[</span><span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">],</span>
           <span class="p">[</span><span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">]]],</span>


         <span class="p">[[[</span>   <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">],</span>
           <span class="p">[</span><span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">],</span>
           <span class="p">[</span><span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">],</span>
           <span class="p">[</span><span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span> <span class="o">-</span><span class="mf">100.</span><span class="p">,</span>    <span class="mf">0.</span><span class="p">]]]]])</span>
</pre></div>
</div>
<p>The forward pass of the window attention module shown earlier contains the following snippet:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    <span class="n">nW</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># number of windows per image, e.g. mask shape [4, 49, 49]</span>
    <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="n">B_</span> <span class="o">//</span> <span class="n">nW</span><span class="p">,</span> <span class="n">nW</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span> <span class="o">+</span> <span class="n">mask</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">unsqueeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># attn: torch.Size([128, 4, 12, 49, 49]), mask broadcast from torch.Size([1, 4, 1, 49, 49])</span>
    <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">N</span><span class="p">)</span>
    <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
    <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span>
</pre></div>
</div>
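<p>The mask is added to the attention scores before the softmax. Masked positions carry a value of -100, so after the softmax their weights are close to zero and the corresponding tokens are effectively ignored. The mask is also discussed in issue #38 of the official repository: <a class="reference external" href="https://github.com/microsoft/Swin-Transformer/issues/38#issuecomment-823806591">The Question about the mask of window attention #38</a></p>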
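<p>The effect of adding -100 before the softmax can be checked with a minimal plain-Python sketch. The logits and the single mask row below are hypothetical values chosen for illustration, and the <code class="docutils literal notranslate"><span class="pre">softmax</span></code> helper is written out by hand rather than taken from PyTorch:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical attention logits of one query token against 4 keys in a window.
logits = [0.5, 1.0, -0.3, 0.8]
# One mask row in the style of the tensor above: 0 keeps a key, -100 suppresses it.
mask_row = [0.0, -100.0, 0.0, -100.0]

# Add the mask to the logits, then normalise, as in the forward pass.
masked = [a + m for a, m in zip(logits, mask_row)]
weights = softmax(masked)
print(weights)
</pre></div>
</div>
<p>Because -100 is finite, the masked weights are extremely small rather than exactly zero, which is enough for the masked tokens to contribute essentially nothing to the weighted sum.</p>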
</section>
</section>
<section id="w-msamsa">
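<h2>Complexity comparison of W-MSA and MSA<a class="headerlink" href="#w-msamsa" title="Permalink to this heading">#</a></h2>
<p>In the original paper, the authors state that <code class="docutils literal notranslate"><span class="pre">W-MSA</span></code>, which computes self-attention inside <strong>local windows</strong>, greatly reduces the amount of computation compared with global <code class="docutils literal notranslate"><span class="pre">MSA</span></code>. To quantify the difference, the paper gives the following two formulas:
<span class="math notranslate nohighlight">\(
\begin{aligned}
&amp;\Omega(\text{MSA})=4 h w C^{2}+2(h w)^{2} C \\
&amp;\Omega(\text{W-MSA})=4 h w C^{2}+2 M^{2} h w C
\end{aligned}
\)</span></p>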
<ul class="simple">
<li><p><strong>h</strong>: height of the feature map</p></li>
<li><p><strong>w</strong>: width of the feature map</p></li>
<li><p><strong>C</strong>: number of channels of the feature map (also called the embedding size)</p></li>
<li><p><strong>M</strong>: window size (each window holds M × M tokens)</p></li>
</ul>
<section id="msa">
<h3>Computational cost of the MSA module<a class="headerlink" href="#msa" title="Permalink to this heading">#</a></h3>
<p>The <code class="docutils literal notranslate"><span class="pre">feature</span> <span class="pre">map</span></code> contains <span class="math notranslate nohighlight">\(hw\)</span> <code class="docutils literal notranslate"><span class="pre">token</span></code>s, each with C channels; denote it by <span class="math notranslate nohighlight">\(X^{h w \times C}\)</span>. Three linear projections <span class="math notranslate nohighlight">\(W_q,W_k,W_v\)</span> produce the <code class="docutils literal notranslate"><span class="pre">q,k,v</span></code> vectors, denoted <span class="math notranslate nohighlight">\(Q^{h w \times C},K^{h w \times C},V^{h w \times C}\)</span> (each with C channels):
<span class="math notranslate nohighlight">\(
\begin{aligned}
X^{h w \times C} \cdot W_{q}^{C \times C}&amp;=Q^{h w \times C} \\
X^{h w \times C} \cdot W_{k}^{C \times C}&amp;=K^{h w \times C} \\
X^{h w \times C} \cdot W_{v}^{C \times C}&amp;=V^{h w \times C}
\end{aligned}
\)</span></p>
<p>Multiplying an <span class="math notranslate nohighlight">\(a \times b\)</span> matrix by a <span class="math notranslate nohighlight">\(b \times c\)</span> matrix costs <span class="math notranslate nohighlight">\(abc\)</span> multiply-accumulate operations, so the three projections cost <span class="math notranslate nohighlight">\(3hwC \times C\)</span>, that is, <span class="math notranslate nohighlight">\(3hwC^2\)</span>. Next comes the attention computation:
<span class="math notranslate nohighlight">\(
\begin{aligned}
Q^{h w \times C} \cdot K^{T}&amp;=A^{h w \times h w} \\
\Lambda^{h w \times h w}&amp;=\mathrm{Softmax}\left(\frac{A^{h w \times h w}}{\sqrt{d}}+B\right) \\
\Lambda^{h w \times h w} \cdot V^{h w \times C}&amp;=Y^{h w \times C}
\end{aligned}
\)</span></p>
<p>Ignoring the cost of the division by <span class="math notranslate nohighlight">\(\sqrt{d}\)</span> and of the softmax, the same matrix-multiplication rule gives <span class="math notranslate nohighlight">\(hw \times C \times hw + hw \times hw \times C\)</span>, that is, <span class="math notranslate nohighlight">\(2(hw)^2C\)</span>. Finally, the result passes through an output Linear layer:
<span class="math notranslate nohighlight">\(
Y^{h w \times C} \cdot W_O^{C \times C}=O^{h w \times C}
\)</span></p>
<p>This projection costs <span class="math notranslate nohighlight">\(hwC^2\)</span>, so the total cost is <span class="math notranslate nohighlight">\(4 h w C^{2}+2(h w)^{2} C\)</span>.</p>
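<p>The four terms of this derivation can be tallied with a short Python sketch. The helper names <code class="docutils literal notranslate"><span class="pre">matmul_flops</span></code> and <code class="docutils literal notranslate"><span class="pre">msa_flops</span></code> are made up for illustration, and the count follows the same simplifications as the text (softmax, scaling and bias are ignored):</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>def matmul_flops(rows, inner, cols):
    """Multiply-accumulate count of a (rows x inner) @ (inner x cols) product."""
    return rows * inner * cols

def msa_flops(h, w, C):
    """Cost of global MSA on an h x w feature map with C channels."""
    n = h * w                            # number of tokens
    qkv = 3 * matmul_flops(n, C, C)      # X @ W_q, X @ W_k, X @ W_v
    scores = matmul_flops(n, C, n)       # Q @ K^T
    context = matmul_flops(n, n, C)      # attention weights @ V
    out = matmul_flops(n, C, C)          # output projection W_O
    return qkv + scores + context + out

h, w, C = 112, 112, 128
closed_form = 4 * h * w * C**2 + 2 * (h * w) ** 2 * C
# The step-by-step tally should agree with the closed-form expression.
print(msa_flops(h, w, C) == closed_form)
</pre></div>
</div>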
</section>
<section id="w-msa">
<h3>Computational cost of the W-MSA module<a class="headerlink" href="#w-msa" title="Permalink to this heading">#</a></h3>
<p>The W-MSA module first partitions the <code class="docutils literal notranslate"><span class="pre">feature</span> <span class="pre">map</span></code> into <span class="math notranslate nohighlight">\(\frac{hw}{M^2}\)</span> windows according to <code class="docutils literal notranslate"><span class="pre">window_size</span></code>, each of height and width <span class="math notranslate nohighlight">\(M\)</span>, and then runs MSA inside each window. Substituting <span class="math notranslate nohighlight">\(h=M,\ w=M\)</span> into the MSA formula above therefore gives the cost of a single window: <span class="math notranslate nohighlight">\(4 M^2 C^{2}+2M^{4} C\)</span>.</p>
<p>Since there are <span class="math notranslate nohighlight">\(\frac{hw}{M^2}\)</span> windows, the total cost is:
<span class="math notranslate nohighlight">\(
\frac{hw}{M^2} \times\left(4M^2 C^2+2M^{4} C\right)=4 h w C^{2}+2 M^{2} h w C
\)</span></p>
<p>For a <code class="docutils literal notranslate"><span class="pre">feature</span> <span class="pre">map</span></code> with <span class="math notranslate nohighlight">\(h=w=112,\ M=7,\ C=128\)</span>, the W-MSA module saves 40124743680 FLOPs compared with the MSA module:
<span class="math notranslate nohighlight">\(
2(h w)^{2} C-2 M^{2} h w C=2 \times 112^{4} \times 128-2 \times 7^{2} \times 112^{2} \times 128=40124743680
\)</span></p>
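<p>The saving can be reproduced numerically with a small Python sketch. The function names are illustrative, and the sketch assumes that <code class="docutils literal notranslate"><span class="pre">h</span></code> and <code class="docutils literal notranslate"><span class="pre">w</span></code> are exact multiples of <code class="docutils literal notranslate"><span class="pre">M</span></code>, so no padding is needed:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>def msa_flops(h, w, C):
    """Closed-form MSA cost: 4hwC^2 + 2(hw)^2 C."""
    return 4 * h * w * C**2 + 2 * (h * w) ** 2 * C

def wmsa_flops(h, w, C, M):
    """W-MSA cost: run MSA on every M x M window and sum over the windows."""
    assert h % M == 0 and w % M == 0, "sketch assumes h and w divisible by M"
    num_windows = (h // M) * (w // M)
    return num_windows * msa_flops(M, M, C)

h, w, C, M = 112, 112, 128, 7
saved = msa_flops(h, w, C) - wmsa_flops(h, w, C, M)
print(saved)  # 40124743680
</pre></div>
</div>
<p>Summing the per-window cost over all windows gives the same total as the closed form <span class="math notranslate nohighlight">\(4 h w C^{2}+2 M^{2} h w C\)</span>, so the difference comes entirely from the attention term.</p>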
</section>
</section>
<section id="id6">
<h2>Overall architecture diagram<a class="headerlink" href="#id6" title="Permalink to this heading">#</a></h2>
<p align=center><img src="./figures/Swin-T.png" alt="Swin-T" style="zoom:80%;" /></p>
<p align=center><img src="./figures/Net.png" alt="Network" style="zoom:50%;" /></p>
<p><img alt="Hyper_parameters" src="../_images/Hyper_parameters.png" /></p>
<blockquote>
<div><p>Reference blog post:</p>
<p>https://zhuanlan.zhihu.com/p/367111046</p>
</div></blockquote>
<blockquote>
<div><p>Contact:</p>
<ul class="simple">
<li><p>Zhihu: https://www.zhihu.com/people/shenhao-63</p></li>
<li><p>GitHub: https://github.com/shenhao-stu</p></li>
</ul>
</div></blockquote>
</section>
</section>


              </div>
              
            </main>
            <footer class="footer-article noprint">
                
    <!-- Previous / next buttons -->
<div class='prev-next-area'>
    <a class='left-prev' id="prev-link" href="ViT%E8%A7%A3%E8%AF%BB.html" title="Previous page">
        <i class="fas fa-angle-left"></i>
        <div class="prev-next-info">
            <p class="prev-next-subtitle">Previous</p>
            <p class="prev-next-title">ViT Explained</p>
        </div>
    </a>
</div>
            </footer>
        </div>
    </div>
    <div class="footer-content row">
        <footer class="col footer"><p>
  
    By ZhikangNiu<br/>
  
      &copy; Copyright 2022, ZhikangNiu.<br/>
</p>
        </footer>
    </div>
    
</div>


      </div>
    </div>
  
  <!-- Scripts loaded after <body> so the DOM is not blocked -->
  <script src="../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>


  </body>
</html>